#!/usr/bin/env python

#
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
#


import argparse
import os.path

from claragenomics.bindings import cuda
from claragenomics.bindings import cudapoa
from claragenomics.io.utils import read_poa_group_file

"""
Sample program for using CUDAPOA Python API for consensus generation.
"""


def run_cudapoa(groups, msa, print_output):
    """
    Generate consensus or MSA for POA groups, one GPU batch at a time.

    Groups are added to a single CUDAPOA batch until the batch reports
    that it is full or no groups are left. The batch is then processed,
    its output optionally printed, and the batch is reset so the
    remaining groups can be added.

    Args:
        groups : A list of groups (i.e. list of sequences) for which
                 consensus is to be generated.
        msa : Whether to generate MSA or consensus.
        print_output : Print output MSA or consensus.
    """
    # Get available memory information
    free, _total = cuda.cuda_get_mem_info(cuda.cuda_get_device())

    # Create batch
    max_sequences_per_poa = 100
    gpu_mem_per_batch = 0.9 * free
    batch = cudapoa.CudaPoaBatch(max_sequences_per_poa,
                                 gpu_mem_per_batch,
                                 stream=None,
                                 output_type=("msa" if msa else "consensus"))

    num_groups = len(groups)
    poa_index = 0        # Index of the next group to hand to the batch.
    groups_in_batch = 0  # Groups accepted since the last reset.
    first_in_batch = 0   # Index of the first group in the current batch.
    last_in_batch = 0    # Index of the last group in the current batch.

    # A while loop (rather than a for loop) lets a group that was rejected
    # because the batch was full be retried against the freshly reset batch.
    while poa_index < num_groups:
        # Status 0 is treated as "group added" and status 1 as "batch
        # full"; any other value is treated as a failure for this group.
        # NOTE(review): confirm these codes against the cudapoa status enum.
        group_status, _seq_status = batch.add_poa_group(groups[poa_index])

        if group_status == 0:
            if groups_in_batch == 0:
                first_in_batch = poa_index
            last_in_batch = poa_index
            groups_in_batch += 1
            poa_index += 1
        elif group_status != 1 or groups_in_batch == 0:
            # Either an error status, or an empty batch already reports
            # "full" for this group (retrying would loop forever).
            # Skip the group but say so instead of dropping it silently.
            print("Could not add POA group {} to batch - status {}".format(
                poa_index, group_status))
            poa_index += 1

        # Nothing to process yet.
        if groups_in_batch == 0:
            continue

        # Keep filling unless the batch is full or no groups are left.
        if group_status != 1 and poa_index < num_groups:
            continue

        batch.generate_poa()

        # Use distinct names for the results so the `msa` flag is not
        # overwritten between batches.
        if msa:
            msa_output, _status = batch.get_msa()
            if print_output:
                print(msa_output)
        else:
            consensus, _coverage, _status = batch.get_consensus()
            if print_output:
                print(consensus)

        print("Processed group {} - {}".format(first_in_batch, last_in_batch))
        batch.reset()
        groups_in_batch = 0


def parse_args(argv=None):
    """
    Parse command line arguments.

    Args:
        argv : Optional list of argument strings to parse. When None
               (the default), argparse reads the arguments from sys.argv.

    Returns:
        An argparse namespace with two boolean attributes: `m` (run MSA
        generation instead of consensus) and `p` (print the output for
        each POA group). Both default to False.
    """
    parser = argparse.ArgumentParser(
        description="CUDAPOA Python API sample program.")
    parser.add_argument('-m',
                        help="Run MSA generation. By default consensus is generated.",
                        action='store_true')
    parser.add_argument('-p',
                        help="Print output MSA or consensus for each POA group.",
                        action='store_true')
    return parser.parse_args(argv)


if __name__ == "__main__":

    # Parse the command line flags (-m for MSA, -p to print output).
    cli_args = parse_args()

    # Locate the sample windows file: it lives under <root>/cudapoa/data,
    # where <root> is two directory levels above this script's directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    repo_root = os.path.dirname(os.path.dirname(script_dir))
    windows_file = os.path.join(repo_root,
                                "cudapoa",
                                "data",
                                "sample-windows.txt")

    # NOTE(review): the second argument (1000) presumably caps how many
    # groups are read from the file - confirm against read_poa_group_file.
    poa_groups = read_poa_group_file(windows_file, 1000)

    # Hand the groups to CUDAPOA.
    run_cudapoa(poa_groups, cli_args.m, cli_args.p)
